import os
import sys
import argparse
import json
import numpy as np
def build_results(lines, preds):
    """Attach a predicted type id to every candidate mention in the test data.

    Args:
        lines: Iterable of JSON-lines strings (e.g. an open text file). Each
            non-blank line must decode to an object with an ``"id"`` key and a
            ``"candidates"`` list whose items each have an ``"id"`` key.
            Blank / whitespace-only lines are skipped.
        preds: Indexable sequence with ``len()`` (e.g. the array loaded from
            the ``.npy`` prediction file) holding one prediction per
            candidate, in the order candidates appear across the whole input.

    Returns:
        A list with one dict per non-blank input line, shaped
        ``{"id": <doc id>, "predictions": [{"id": <mention id>,
        "type_id": <int>}, ...]}``.

    Raises:
        ValueError: If the number of predictions does not equal the total
            number of candidates in ``lines``.
    """
    n_preds = len(preds)
    results = []
    cursor = 0  # index of the next unused prediction; runs across documents
    for line in lines:
        if not line.strip():
            continue
        data = json.loads(line)
        predictions = []
        for mention in data["candidates"]:
            if cursor >= n_preds:
                raise ValueError(
                    "more candidates in the test data than predictions "
                    f"({n_preds}) in the prediction file"
                )
            # int() turns numpy scalars into plain ints so json.dumps accepts them.
            predictions.append({"id": mention["id"], "type_id": int(preds[cursor])})
            cursor += 1
        results.append({"id": data["id"], "predictions": predictions})
    if cursor != n_preds:
        raise ValueError(
            f"prediction count ({n_preds}) does not match candidate count ({cursor})"
        )
    return results


def main(argv=None):
    """Parse CLI arguments, match predictions to candidates and write the results.

    Args:
        argv: Optional argument list; defaults to ``sys.argv[1:]`` via argparse.

    Raises:
        ValueError: Propagated from ``build_results`` on a prediction/candidate
            count mismatch. The output file is not written in that case.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--test_data", default="../maven/test.jsonl", help="path to the test data file", required=False)
    parser.add_argument("--preds", default="MAVEN/_preds.npy", help="path to the prediction file generated by the run_MAVEN_infer.sh script")
    parser.add_argument("--output", default="../maven/results.jsonl", help="path to the output file")
    args = parser.parse_args(argv)

    preds = np.load(args.preds)
    with open(args.test_data, "r", encoding="utf-8") as fin:
        results = build_results(fin, preds)

    # Open the output only after validation succeeds, so a mismatch never
    # leaves a partial or inconsistent results file behind.
    with open(args.output, "w", encoding="utf-8") as fout:
        for res in results:
            fout.write(json.dumps(res) + "\n")


if __name__ == '__main__':
    main()
